{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "db53d67f-cef1-4fb9-aad0-771d0c09bf77",
   "metadata": {},
   "source": [
    "# Qianfan and Deep Lake"
   ]
  },
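  {
   "cell_type": "markdown",
   "id": "3f5ee127-21a8-4d2b-9b52-aa529da186c4",
   "metadata": {},
   "source": [
    "Deep Lake is a vector database for AI that stores large collections of datasets and vectors and can be used to build large language model applications.\n",
    "\n",
    "This notebook shows how to use the Qianfan SDK together with Deep Lake to implement retrieval-based question answering."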
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17809625",
   "metadata": {},
   "source": [
    "This notebook uses the following version of the Qianfan SDK:\n",
    "````\n",
    "qianfan>=0.1.4\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b8c37ee-9f6d-42e7-9f58-aabbe6c168d6",
   "metadata": {},
   "source": [
    "## 1. Setup"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "416481c1-2ad8-4679-a6c0-4b2bd5dbac25",
   "metadata": {},
   "source": [
    "To use Deep Lake, we first need to install its Python library.\n",
    "This notebook uses the following Deep Lake version:\n",
    "```\n",
    "deeplake==3.8.6\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de769f2-91a6-4634-9dc1-84872cd6890d",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install deeplake==3.8.6"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed1715f7-e524-4cfb-ad20-464667e4686b",
   "metadata": {},
   "source": [
    "## 2. Authentication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "153f5e06-7011-4da1-aacd-24a07c69bc9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import qianfan\n",
    "\n",
    "# Configure the Qianfan SDK credentials (replace the placeholders with your own AK/SK)\n",
    "os.environ[\"QIANFAN_AK\"] = \"your_ak\"\n",
    "os.environ[\"QIANFAN_SK\"] = \"your_sk\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841ca480-bce5-4c86-a499-861b8a4935ca",
   "metadata": {},
   "source": [
    "## 3. Prepare the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c4e00e78-9c73-4cc2-904d-58f2dd18a836",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opening dataset in read-only mode as you don't have write permissions.\n",
      "This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/cohere-wikipedia-22-sample\n",
      "\n",
      "hub://activeloop/cohere-wikipedia-22-sample loaded successfully.\n",
      "\n",
      "Dataset(path='hub://activeloop/cohere-wikipedia-22-sample', read_only=True, tensors=['ids', 'metadata', 'text'])\n",
      "\n",
      "  tensor    htype     shape      dtype  compression\n",
      " -------   -------   -------    -------  ------- \n",
      "   ids      text    (20000, 1)    str     None   \n",
      " metadata   json    (20000, 1)    str     None   \n",
      "   text     text    (20000, 1)    str     None   \n"
     ]
    }
   ],
   "source": [
    "import deeplake\n",
    "\n",
    "# Load a public sample dataset from the Deep Lake hub\n",
    "ds = deeplake.load(\"hub://activeloop/cohere-wikipedia-22-sample\")\n",
    "ds.summary()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bc5ad06-b02f-49f0-9e0d-3cd662ba6807",
   "metadata": {},
   "source": [
    "## 4. Build the text index"
   ]
  },
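  {
   "cell_type": "markdown",
   "id": "aa2f9f93-b1a0-4ddf-aec8-61b45bb26127",
   "metadata": {},
   "source": [
    "To make the data searchable later, we first need to convert the text in the dataset into vectors. Here we use the Qianfan Embedding endpoint for that conversion."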
   ]
  },
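  {
   "cell_type": "markdown",
   "id": "7c1e5a90-3b2d-4f6a-8e41-2d9b0c6f1a01",
   "metadata": {},
   "source": [
    "To illustrate why turning text into vectors makes retrieval possible, the sketch below ranks a few stored vectors by cosine similarity to a query vector and returns the closest text. It is only a toy under stated assumptions: the 3-dimensional vectors, the sample texts, and the helper names `cosine_similarity` and `nearest` are made up for illustration, and they are neither Qianfan embeddings nor Deep Lake's actual search code.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Dot product divided by the product of the two vector norms.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "def nearest(query_vec, index):\n",
    "    # index maps each text to its (toy) embedding vector.\n",
    "    return max(index, key=lambda text: cosine_similarity(query_vec, index[text]))\n",
    "\n",
    "# Hypothetical 3-dimensional vectors standing in for real embeddings.\n",
    "toy_index = {\n",
    "    'clock formats': [0.9, 0.1, 0.0],\n",
    "    'ocean currents': [0.0, 0.8, 0.6],\n",
    "    'chess openings': [0.1, 0.0, 0.9],\n",
    "}\n",
    "best = nearest([0.8, 0.2, 0.1], toy_index)\n",
    "print(best)\n",
    "```\n",
    "\n",
    "In this notebook, the embedding endpoint produces the vectors and the vector store handles the search, so none of this needs to be written by hand."
   ]
  },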
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c8f03b20-70a7-4afe-a93a-185b6df5f67d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint\n",
    "from langchain.vectorstores import DeepLake\n",
    "\n",
    "dataset_path = 'wikipedia-embeddings-deeplake'\n",
    "embedding = QianfanEmbeddingsEndpoint()\n",
    "db = DeepLake(dataset_path, embedding=embedding, overwrite=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3210baf6-a293-4b42-9d51-43783cc328e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Creating 10 embeddings in 1 batches of size 10:: 100%|██████████| 1/1 [00:00<00:00,  1.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset(path='wikipedia-embeddings-deeplake', tensors=['text', 'metadata', 'embedding', 'id'])\n",
      "\n",
      "  tensor      htype      shape     dtype  compression\n",
      "  -------    -------    -------   -------  ------- \n",
      "   text       text      (10, 1)     str     None   \n",
      " metadata     json      (10, 1)     str     None   \n",
      " embedding  embedding  (10, 384)  float32   None   \n",
      "    id        text      (10, 1)     str     None   \n",
      "Dataset(path='wikipedia-embeddings-deeplake', tensors=['text', 'metadata', 'embedding', 'id'])\n",
      "\n",
      "  tensor      htype      shape     dtype  compression\n",
      "  -------    -------    -------   -------  ------- \n",
      "   text       text      (10, 1)     str     None   \n",
      " metadata     json      (10, 1)     str     None   \n",
      " embedding  embedding  (10, 384)  float32   None   \n",
      "    id        text      (10, 1)     str     None   \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Embed the texts in batches and add them to the vector store\n",
    "batch_size = 100\n",
    "\n",
    "nsamples = 10  # demo only: index just a handful of texts\n",
    "for i in range(0, nsamples, batch_size):\n",
    "    # find end of batch\n",
    "    i_end = min(nsamples, i + batch_size)\n",
    "\n",
    "    batch = ds[i:i_end]\n",
    "    id_batch = batch.ids.data()[\"value\"]\n",
    "    text_batch = batch.text.data()[\"value\"]\n",
    "    meta_batch = batch.metadata.data()[\"value\"]\n",
    "\n",
    "    db.add_texts(text_batch, metadatas=meta_batch, ids=id_batch)\n",
    "\n",
    "db.vectorstore.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4267fb51-7116-48ba-b004-79241cceeba1",
   "metadata": {},
   "source": [
    "## 5. Implement retrieval QA"
   ]
  },
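  {
   "cell_type": "markdown",
   "id": "1d4700dd-dfce-4e46-a284-0d972d327117",
   "metadata": {},
   "source": [
    "We rely on LangChain's implementation to build this quickly. LangChain wraps the process of retrieving the relevant text, so we only need to adjust the prompt template used for the final request to the model."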
   ]
  },
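  {
   "cell_type": "markdown",
   "id": "b4d82f17-6a3c-4e95-9f20-8c1a7e5d3b02",
   "metadata": {},
   "source": [
    "As a rough picture of what happens to such a template, the sketch below joins a few retrieved passages, substitutes them for `{context}`, and puts the user's query in `{question}`. It uses plain `str.format` rather than LangChain's own code; the helper name `build_prompt`, the blank-line separator, and the two sample passages are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "TEMPLATE = '''Use the following pieces of context to answer the user's question.\n",
    "----------------\n",
    "{context}\n",
    "----------------\n",
    "Now answer the question:\n",
    "{question}\n",
    "'''\n",
    "\n",
    "def build_prompt(passages, question):\n",
    "    # Join the retrieved passages and fill in both placeholders.\n",
    "    return TEMPLATE.format(context='\\n\\n'.join(passages), question=question)\n",
    "\n",
    "# Hypothetical passages standing in for what the retriever would return.\n",
    "prompt = build_prompt(\n",
    "    ['Passage A about 24-hour time.', 'Passage B about midnight.'],\n",
    "    'Why does the military not say 24:00?',\n",
    ")\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "Because every retrieved passage ends up in a single prompt, the total length of the passages has to fit within the model's input limit."
   ]
  },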
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ce40750f-a41a-4f4f-b16d-02c57dae5e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the prompt template\n",
    "from langchain.prompts.chat import (\n",
    "    ChatPromptTemplate,\n",
    "    HumanMessagePromptTemplate,\n",
    ")\n",
    "\n",
    "template = \"\"\"Use the following pieces of context to answer the user's question. \n",
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
    "----------------\n",
    "{context}\n",
    "----------------\n",
    "Now answer the question:\n",
    "{question}\n",
    "\"\"\"\n",
    "\n",
    "messages = [\n",
    "    HumanMessagePromptTemplate.from_template(template),\n",
    "]\n",
    "CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c81e1258-387e-4a1d-bcec-a5b1e0d63c6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import RetrievalQA\n",
    "from langchain.chat_models import QianfanChatEndpoint\n",
    "\n",
    "qa = RetrievalQA.from_chain_type(\n",
    "    llm=QianfanChatEndpoint(model='ERNIE-Bot'),\n",
    "    chain_type_kwargs={\n",
    "        \"prompt\": CHAT_PROMPT\n",
    "    },\n",
    "    chain_type=\"stuff\", \n",
    "    retriever=db.as_retriever(), \n",
    "    verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dea84e72-3de1-4518-a095-c7950ecddfde",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new RetrievalQA chain...\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The military does not say 24:00 because they do not like to have two names for the same thing. Instead, they always say \"23:59\", which is one minute before midnight. This is because using \"24:00\" could potentially confuse people who are familiar with the 12-hour clock format.'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "query = 'Why does the military not say 24:00?'\n",
    "qa.run(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9657fb8b-0c9f-40ff-aed1-7e7c88c5f10e",
   "metadata": {},
   "source": [
    "Done!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
